{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gensim 文本处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import random\n",
    "import re\n",
    "import os\n",
    "from pprint import pprint\n",
    "\n",
    "from time import time\n",
    "\n",
    "import gensim\n",
    "from gensim.models import Word2Vec, Doc2Vec\n",
    "from gensim import utils\n",
    "\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "\n",
    "from sklearn.feature_extraction.stop_words import ENGLISH_STOP_WORDS\n",
    "\n",
    "from collections import defaultdict\n",
    "from smart_open import open\n",
    "from six import iteritems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "txt_path = 'https://radimrehurek.com/gensim/mycorpus.txt'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 小数据处理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 文本\n",
    "list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocessor1(text):\n",
    "    text = re.sub('<[^>]*>', '', text)\n",
    "    text = re.sub('[\\W]+', ' ', text.lower())\n",
    "    # text = text.split()\n",
    "    return text\n",
    "token_review = [preprocessor1(i) for i in open(txt_path) ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'human machine interface for lab abc computer applications '"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_review[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 停用词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_stopwords = [] # 自定义停用词列表\n",
    "STOPWORDS = stopwords.words(\"english\") + list(ENGLISH_STOP_WORDS) + custom_stopwords\n",
    "STOPWORDS = list(set(STOPWORDS))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 分词 转小写 删除停用词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = [[word for word in document.lower().split() if word not in STOPWORDS]\n",
    "         for document in token_review]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 词频计算\n",
    "\n",
    "留下出现次数大于X的词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['human', 'interface', 'computer'],\n",
      " ['survey', 'user', 'computer', 'response', 'time'],\n",
      " ['eps', 'user', 'interface'],\n",
      " ['human', 'eps'],\n",
      " ['user', 'response', 'time'],\n",
      " ['trees'],\n",
      " ['graph', 'trees'],\n",
      " ['graph', 'minors', 'trees'],\n",
      " ['graph', 'minors', 'survey']]\n"
     ]
    }
   ],
   "source": [
    "frequency = defaultdict(int)\n",
    "for text in texts:\n",
    "    for token in text:\n",
    "        frequency[token] += 1\n",
    "\n",
    "X = 1\n",
    "processed_corpus = [[token for token in text if frequency[token] > X] for text in texts]\n",
    "pprint(processed_corpus)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 单词与整数ID的map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dictionary(11 unique tokens: ['computer', 'human', 'interface', 'response', 'survey']...)\n"
     ]
    }
   ],
   "source": [
    "dictionary = gensim.corpora.Dictionary(processed_corpus)\n",
    "dictionary.save('./idx2words_map.dict')\n",
    "print(dictionary)\n",
    "num_features = 11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: 2, 2: 2, 0: 2, 4: 2, 6: 3, 3: 2, 5: 2, 7: 2, 8: 3, 9: 3, 10: 2}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dictionary.dfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "给所有词指定唯一的id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'computer': 0,\n",
      " 'eps': 7,\n",
      " 'graph': 9,\n",
      " 'human': 1,\n",
      " 'interface': 2,\n",
      " 'minors': 10,\n",
      " 'response': 3,\n",
      " 'survey': 4,\n",
      " 'time': 5,\n",
      " 'trees': 8,\n",
      " 'user': 6}\n"
     ]
    }
   ],
   "source": [
    "pprint(dictionary.token2id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用doc2bow字典的方法为文档创建词袋表示法：\n",
    "\n",
    "每个元组中的第一个条目对应于字典中令牌的ID，第二个条目对应于此令牌的计数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[(0, 1), (1, 1), (2, 1)],\n",
      " [(0, 1), (3, 1), (4, 1), (5, 1), (6, 1)],\n",
      " [(2, 1), (6, 1), (7, 1)],\n",
      " [(1, 1), (7, 1)],\n",
      " [(3, 1), (5, 1), (6, 1)],\n",
      " [(8, 1)],\n",
      " [(8, 1), (9, 1)],\n",
      " [(8, 1), (9, 1), (10, 1)],\n",
      " [(4, 1), (9, 1), (10, 1)]]\n"
     ]
    }
   ],
   "source": [
    "bow_corpus = [dictionary.doc2bow(text) for text in processed_corpus]\n",
    "gensim.corpora.MmCorpus.serialize('./texts_bow.mm', bow_corpus)\n",
    "pprint(bow_corpus)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练tf-idf模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfidf = gensim.models.TfidfModel(bow_corpus)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "该tfidf模型再次返回一个元组列表，其中第一个条目是令牌ID，第二个条目是tf-idf权重。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0, 0.5773502691896257), (1, 0.5773502691896257), (2, 0.5773502691896257)]\n"
     ]
    }
   ],
   "source": [
    "print(tfidf[bow_corpus[0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 相似性查询"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0, 0.99999994), (1, 0.27115756), (2, 0.36272496), (3, 0.40824828), (4, 0.0), (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0)]\n"
     ]
    }
   ],
   "source": [
    "index = gensim.similarities.SparseMatrixSimilarity(tfidf[bow_corpus],num_features=num_features)\n",
    "query_bow = bow_corpus[0]\n",
    "sims = index[tfidf[query_bow]]\n",
    "print(list(enumerate(sims)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.99999994\n",
      "3 0.40824828\n",
      "2 0.36272496\n",
      "1 0.27115756\n",
      "4 0.0\n",
      "5 0.0\n",
      "6 0.0\n",
      "7 0.0\n",
      "8 0.0\n"
     ]
    }
   ],
   "source": [
    "for document_number, score in sorted(enumerate(sims), key=lambda x: x[1], reverse=True):\n",
    "    print(document_number, score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 大数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 语料流\n",
    "每行一个文本 词间以空格隔开\n",
    "\n",
    "语料库不加载到RAM，因为一次最多只有一个向量驻留在RAM中。语料库现在可以随您想要的大小而变大。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyCorpusRaw(object):\n",
    "    def __iter__(self):\n",
    "        for line in open(txt_path):\n",
    "            yield line.lower().split()\n",
    "\n",
    "corpus_raw_memory_friendly = MyCorpusRaw()\n",
    "dictionary_ = gensim.corpora.Dictionary(corpus_raw_memory_friendly)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_stopwords = [] # 自定义停用词列表\n",
    "STOPWORDS = stopwords.words(\"english\") + list(ENGLISH_STOP_WORDS) + custom_stopwords\n",
    "STOPWORDS = list(set(STOPWORDS))\n",
    "stop_ids = [\n",
    "    dictionary.token2id[stopword]\n",
    "    for stopword in STOPWORDS\n",
    "    if stopword in dictionary.token2id\n",
    "]\n",
    "once_ids = [tokenid for tokenid, docfreq in iteritems(dictionary_.dfs) if docfreq == 1]\n",
    "dictionary_.filter_tokens(stop_ids + once_ids)\n",
    "dictionary_.compactify()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0, 1), (1, 1), (2, 1)]\n",
      "[(0, 1), (3, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1), (9, 1)]\n",
      "[(2, 1), (7, 1), (9, 1), (10, 1), (11, 1)]\n",
      "[(1, 1), (4, 1), (7, 2), (10, 1), (12, 1)]\n",
      "[(4, 1), (5, 1), (8, 1), (9, 1)]\n",
      "[(4, 1), (11, 1), (13, 1)]\n",
      "[(4, 1), (11, 1), (13, 1), (14, 1)]\n",
      "[(4, 1), (12, 1), (13, 1), (14, 1), (15, 1)]\n",
      "[(3, 1), (6, 1), (14, 1), (15, 1)]\n"
     ]
    }
   ],
   "source": [
    "class MyCorpus(object):\n",
    "    def __iter__(self):\n",
    "        for line in open(txt_path):\n",
    "            yield dictionary_.doc2bow(line.lower().split())\n",
    "corpus_memory_friendly = MyCorpus()\n",
    "for vector in corpus_memory_friendly:\n",
    "    print(vector)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocessor2(text):\n",
    "    text = re.sub('<[^>]*>', '', text)\n",
    "    text = re.sub('[\\W]+', ' ', text.lower())\n",
    "    text = text.split()\n",
    "    return text\n",
    "token_review = [preprocessor2(i) for i in open(txt_path) ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_dim = 200\n",
    "model = gensim.models.Word2Vec(token_review, min_count=2, size=vector_dim)\n",
    "#model.train(sentences,total_examples=len(sentences),epochs=10)\n",
    "model.save('word2vectors.bin')\n",
    "model.wv.save_word2vec_format('word2vectors.txt', binary=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "相似词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('of', 0.10442149639129639),\n",
       " ('response', 0.0843486338853836),\n",
       " ('system', 0.05847841501235962),\n",
       " ('trees', 0.0459064245223999),\n",
       " ('survey', 0.04115157574415207),\n",
       " ('time', 0.030271658673882484)]"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.most_similar(positive=\"human\",topn=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('of', 0.12239014357328415),\n",
       " ('response', 0.07446271181106567),\n",
       " ('survey', 0.039067476987838745),\n",
       " ('graph', 0.02384733408689499),\n",
       " ('user', 0.0014887526631355286),\n",
       " ('system', -0.0035650990903377533),\n",
       " ('trees', -0.019153958186507225),\n",
       " ('a', -0.033130671828985214),\n",
       " ('minors', -0.04246136546134949),\n",
       " ('the', -0.05228729918599129)]"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w1 = [\"human\",'computer']\n",
    "w2 = ['time'] # 排除\n",
    "model.wv.most_similar(positive=w1,negative=w2,topn=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "相似度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0036251114"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.similarity(w1=\"human\",w2=\"user\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "非同类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lzy/env/pt/lib/python3.7/site-packages/gensim/models/keyedvectors.py:877: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
      "  vectors = vstack(self.word_vec(word, use_norm=True) for word in used_words).astype(REAL)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'human'"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.doesnt_match([\"human\",\"user\",\"trees\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "vector 向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2.3988045e-03,  1.2539937e-03,  6.9262099e-04,  4.9847364e-04,\n",
       "       -1.8678217e-03, -8.5406253e-05, -2.3130092e-03,  1.0084164e-03,\n",
       "       -8.6981483e-04,  1.5369804e-03,  2.4022781e-03, -2.1757684e-03,\n",
       "       -2.1571585e-04,  1.3833324e-03, -9.9270116e-04, -2.0988514e-03,\n",
       "       -2.1577813e-03,  2.4105872e-03,  8.2103058e-04,  2.0665555e-04,\n",
       "        1.4734926e-03, -1.7346102e-03,  4.8905431e-04,  6.4237748e-04,\n",
       "        1.8797844e-03,  1.0360393e-03,  8.0722343e-04, -2.3487874e-03,\n",
       "       -2.3470134e-03,  2.1651248e-03,  1.2582765e-03, -7.2150916e-04,\n",
       "        1.0646669e-03,  1.6872594e-03, -2.1409900e-03,  7.5836171e-04,\n",
       "       -6.7712233e-04, -1.7152717e-03,  4.4151231e-05, -1.0926957e-03,\n",
       "        1.7349893e-03, -1.4390079e-03, -1.6277698e-03,  2.3250822e-03,\n",
       "        1.5569222e-03,  1.3235764e-03, -2.0074495e-03, -1.8302349e-03,\n",
       "       -2.0173925e-03, -1.4989993e-03,  7.2281488e-04, -6.9727673e-04,\n",
       "        9.7435206e-04, -1.5983926e-04, -2.0526850e-03,  1.5874363e-03,\n",
       "       -2.3918210e-03, -5.8608560e-04,  5.3761684e-04, -9.1485266e-04,\n",
       "        2.1672316e-03,  1.2641036e-03,  1.0099575e-03,  2.1097695e-03,\n",
       "       -1.7723533e-03,  7.8461302e-04,  2.1815652e-03, -3.3093467e-05,\n",
       "       -1.9342861e-04, -1.7271067e-03, -1.8816154e-03,  7.9081423e-04,\n",
       "       -1.8019675e-03, -1.1574307e-03,  3.8036393e-04,  2.4716693e-03,\n",
       "        2.4870746e-03,  1.8114661e-03, -3.8601164e-04, -6.5278047e-04,\n",
       "       -2.4311056e-03, -2.2802572e-03, -1.5421491e-03,  3.5922005e-04,\n",
       "       -6.5251091e-04,  1.8837660e-03,  1.9234738e-03,  3.3269770e-04,\n",
       "        2.1695599e-03, -2.5759215e-04,  6.2802486e-04,  1.3322606e-03,\n",
       "       -2.1524611e-03,  1.7699624e-03,  4.6450255e-04, -1.6834307e-03,\n",
       "       -1.5465614e-03, -2.0237842e-03, -2.2742320e-03,  1.5039262e-03,\n",
       "       -6.5247441e-04, -1.4317570e-03, -1.4992742e-03, -5.5570039e-04,\n",
       "        2.2312475e-03,  1.2884900e-03, -1.5699820e-03,  5.8275671e-04,\n",
       "        1.3170377e-03,  1.6010281e-03, -7.0003822e-04,  2.0145180e-03,\n",
       "        8.4862266e-05, -1.2002197e-03,  7.7350193e-04, -1.9482737e-03,\n",
       "        2.0212661e-03,  3.6307058e-05,  1.3202550e-03, -2.3309694e-04,\n",
       "        1.4449540e-03,  1.2700385e-03,  9.6885889e-04,  2.2402783e-03,\n",
       "       -1.8298622e-03,  2.2558616e-03, -1.4717102e-03,  9.9012999e-05,\n",
       "        1.1589180e-03, -2.6858345e-04,  1.3452298e-04,  1.1816638e-03,\n",
       "       -7.3201826e-04, -2.1017746e-03, -1.7845734e-03, -1.7626407e-04,\n",
       "       -2.0131881e-03, -1.8699565e-03, -9.2089045e-05,  2.1329701e-03,\n",
       "        5.7339936e-04, -7.1315008e-04, -5.8188243e-04,  9.2584436e-04,\n",
       "        2.0128896e-03,  2.4097490e-03,  2.1670305e-04, -1.7942565e-03,\n",
       "       -2.6200199e-04,  1.0202839e-03,  1.1144418e-03,  6.3139305e-04,\n",
       "       -4.2817480e-04,  1.5332211e-03, -1.4675662e-03,  2.5194566e-04,\n",
       "       -1.5720694e-03, -3.4617589e-04,  2.0052905e-03, -9.4317924e-04,\n",
       "       -1.7889714e-03, -1.6965104e-03,  1.5361219e-03,  1.3715095e-03,\n",
       "        1.8049758e-03, -2.4843479e-03,  1.2660018e-03, -1.8439008e-03,\n",
       "        8.4767124e-04,  3.1367698e-04,  2.0403268e-03,  2.3827222e-03,\n",
       "        3.3278548e-04,  2.0835248e-03, -3.1300515e-04, -1.3001665e-03,\n",
       "       -4.5671687e-04,  5.2194239e-04, -2.0564413e-03, -2.4471085e-03,\n",
       "       -7.8056636e-04,  6.1602442e-04,  1.3014206e-03, -8.7890209e-04,\n",
       "       -1.9402981e-03,  2.1903638e-03, -3.5908431e-04, -6.9569389e-04,\n",
       "       -8.1450981e-04,  1.1840992e-03, -2.8985229e-04, -1.1280138e-03,\n",
       "       -1.2541622e-03, -6.7020050e-04,  8.7372627e-04, -9.6441869e-04,\n",
       "        3.8043261e-05, -9.6013653e-05, -2.3991407e-03,  7.4033230e-04],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv[\"trees\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 /5 epochs\n",
      "2 /5 epochs\n",
      "3 /5 epochs\n",
      "4 /5 epochs\n",
      "5 /5 epochs\n",
      "Time taken for Word2vec training:  0.0003178755442301432  (mins).\n",
      "similar words of 'human' [('computer', 0.10129371285438538), ('survey', 0.0811525285243988), ('a', 0.07710925489664078), ('system', 0.07218603044748306), ('of', 0.07031609117984772), ('and', 0.06266631186008453), ('user', 0.061590321362018585), ('trees', -0.004959670826792717), ('response', -0.035101912915706635), ('graph', -0.05454497039318085)]\n",
      "1 /5 epochs\n",
      "2 /5 epochs\n",
      "3 /5 epochs\n",
      "4 /5 epochs\n",
      "5 /5 epochs\n",
      "Time taken for Word2vec training:  0.0003329038619995117  (mins).\n",
      "similar words of 'human' [('response', 0.07876017689704895), ('survey', 0.03825037181377411), ('of', 0.00648870412260294), ('computer', 0.0046798065304756165), ('and', 0.004269154742360115), ('user', -0.0034093516878783703), ('system', -0.016295716166496277), ('minors', -0.016591623425483704), ('time', -0.03170005977153778), ('eps', -0.032730832695961)]\n"
     ]
    }
   ],
   "source": [
    "dims = [100, 600]\n",
    "for size in dims:\n",
    "    #instantiate our  model\n",
    "    model_w2v = Word2Vec(min_count=2, window=5, size=size, sample=1e-3, negative=5, workers=4, sg=0)\n",
    "\n",
    "    #build vocab over all reviews\n",
    "    model_w2v.build_vocab(token_review)\n",
    "\n",
    "    #We pass through the data set multiple times, shuffling the training reviews each time to improve accuracy.\n",
    "    Idx=list(range(len(token_review)))\n",
    "\n",
    "    t0 = time()\n",
    "    for epoch in range(5):\n",
    "         print(epoch+1, \"/5 epochs\")\n",
    "         random.shuffle(Idx)\n",
    "         perm_sentences = [token_review[i] for i in Idx]\n",
    "         model_w2v.train(perm_sentences,total_examples=len(Idx), epochs = 1)\n",
    "\n",
    "    elapsed=time() - t0\n",
    "    print(\"Time taken for Word2vec training: \", elapsed/60, \" (mins).\")\n",
    "\n",
    "    # saves the word2vec model to be used later.\n",
    "    #model_w2v.save('./model_word2vec_skipgram_300dim')\n",
    "\n",
    "    # open a saved word2vec model\n",
    "    #model_w2v=gensim.models.Word2Vec.load('./model_word2vec')\n",
    "\n",
    "    model_w2v.wv.save_word2vec_format('./data/model_word2vec_v2_%ddim.txt'%size, binary=False)\n",
    "    print(\"similar words of 'human'\", model_w2v.wv.most_similar('human') )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt",
   "language": "python",
   "name": "pt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
